ai-edge-torch-nightly 0.3.0.dev20240914__py3-none-any.whl → 0.3.0.dev20240918__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (20)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
  3. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
  4. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
  5. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
  6. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
  7. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
  8. ai_edge_torch/generative/layers/attention.py +8 -4
  9. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -0
  10. ai_edge_torch/generative/layers/unet/model_config.py +2 -0
  11. ai_edge_torch/generative/utilities/converter.py +82 -0
  12. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +6 -0
  13. ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
  14. ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
  15. ai_edge_torch/version.py +1 -1
  16. {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/RECORD +20 -19
  18. {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/LICENSE +0 -0
  19. {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/WHEEL +0 -0
  20. {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma2 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma2.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma2_to_tflite(path)
+  app.run(main)
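
Note: with this change, the example converters are driven by absl flags instead of hard-coded function arguments, so exports can be customized from the command line. A minimal sketch of the pattern (the script below is illustrative; the real flag definitions are in the diff above):

    from absl import app
    from absl import flags

    _QUANTIZE = flags.DEFINE_bool('quantize', True, 'Whether to quantize.')

    def main(_):
      # Flag values are available via .value once app.run() has parsed argv.
      print('quantize =', _QUANTIZE.value)

    if __name__ == '__main__':
      app.run(main)  # e.g. `python script.py --noquantize` flips the bool flag

The same refactoring is applied to each of the convert_to_tflite.py examples below.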
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_openelm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts OpenELM model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = openelm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
-  convert_openelm_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.layers import kv_cache
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_phi2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Phi-2 model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = phi2.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/smollm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_smollm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts SmolLM model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = smollm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
-  convert_smollm_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py CHANGED
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=mid_block_channels,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=mid_block_channels,
+                output_dim=mid_block_channels,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_tiny_llama_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts TinyLlama model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = tiny_llama.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
       batch_size: int,
       query_dim: int,
      cross_dim: int,
+      hidden_dim: int,
+      output_dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ):
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
+      hidden_dim (int): hidden dimension that q, k, v tensors project to.
+      output_dim (int): output tensor's dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
     self.config = config
     self.n_heads = config.num_heads
     self.q_projection = nn.Linear(
-        query_dim, query_dim, bias=config.qkv_use_bias
+        query_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.k_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.v_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.output_projection = nn.Linear(
-        query_dim, query_dim, bias=config.output_proj_use_bias
+        hidden_dim, output_dim, bias=config.output_proj_use_bias
    )
 
     self.sdpa_func = (
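
For context, here is a minimal single-head sketch of the new projection shapes (the dimension values are hypothetical, and the real layer also splits hidden_dim across config.num_heads):

    import torch
    from torch import nn
    import torch.nn.functional as F

    query_dim, cross_dim, hidden_dim, output_dim = 320, 768, 320, 320
    q_proj = nn.Linear(query_dim, hidden_dim)     # was query_dim -> query_dim
    k_proj = nn.Linear(cross_dim, hidden_dim)     # was cross_dim -> query_dim
    v_proj = nn.Linear(cross_dim, hidden_dim)     # was cross_dim -> query_dim
    out_proj = nn.Linear(hidden_dim, output_dim)  # was query_dim -> query_dim

    x = torch.randn(1, 64, query_dim)    # latent queries
    ctx = torch.randn(1, 77, cross_dim)  # e.g. text-encoder hidden states
    y = out_proj(
        F.scaled_dot_product_attention(q_proj(x), k_proj(ctx), v_proj(ctx))
    )
    assert y.shape == (1, 64, output_dim)

Decoupling the hidden and output widths from query_dim lets callers configure attention blocks whose internal width differs from the query width; the UNet call sites in this release pass the block's channel count for both, preserving the previous behavior.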
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -178,6 +178,8 @@ class CrossAttentionBlock2D(nn.Module):
         config.attention_batch_size,
         config.query_dim,
         config.cross_dim,
+        config.hidden_dim,
+        config.output_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
 class CrossAttentionBlock2DConfig:
   query_dim: int
   cross_dim: int
+  hidden_dim: int
+  output_dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
ai_edge_torch/generative/utilities/converter.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utility functions for model conversion."""
+
+import ai_edge_torch
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_to_tflite(
+    pytorch_model: torch.nn.Module,
+    tflite_path: str,
+    prefill_seq_len: int = 512,
+    quantize: bool = True,
+):
+  """Converts a nn.Module model to multi-signature tflite model.
+
+  A PyTorch model will be converted to a tflite model with two signatures:
+  "prefill" and "decode".
+
+  "prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
+  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
+  external KV cache as a sample input.
+
+  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
+  of shape [1, 1] of the token position, and an external KV cache as a sample
+  input.
+
+  The final tflite model will be exported to tflite_path.
+
+  Args:
+    pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
+    tflite_path (str): The tflite file path to export.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
+  """
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(tflite_path)
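
The new utility centralizes the prefill/decode two-signature conversion that every example script previously duplicated. A usage sketch (the checkpoint path is a placeholder):

    from ai_edge_torch.generative.examples.gemma import gemma2
    from ai_edge_torch.generative.utilities import converter

    # Build the PyTorch model, then export both signatures in one call.
    pytorch_model = gemma2.build_2b_model(
        '/path/to/gemma2-2b', kv_cache_max_len=1024
    )
    converter.convert_to_tflite(
        pytorch_model,
        tflite_path='/tmp/gemma2_q8_seq512_ekv1024.tflite',
        prefill_seq_len=512,
        quantize=True,
    )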
ai_edge_torch/generative/utilities/stable_diffusion_loader.py CHANGED
@@ -811,6 +811,8 @@ class DiffusionModelLoader(BaseLoader):
         cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
             query_dim=output_channel,
             cross_dim=config.transformer_cross_attention_dim,
+            hidden_dim=output_channel,
+            output_dim=output_channel,
             normalization_config=config.transformer_norm_config,
             attention_config=build_attention_config(
                 num_heads=config.transformer_num_attention_heads,
@@ -877,6 +879,8 @@ class DiffusionModelLoader(BaseLoader):
         cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
             query_dim=mid_block_channels,
             cross_dim=config.transformer_cross_attention_dim,
+            hidden_dim=mid_block_channels,
+            output_dim=mid_block_channels,
             normalization_config=config.transformer_norm_config,
             attention_config=build_attention_config(
                 num_heads=config.transformer_num_attention_heads,
@@ -950,6 +954,8 @@ class DiffusionModelLoader(BaseLoader):
         cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
             query_dim=output_channel,
             cross_dim=config.transformer_cross_attention_dim,
+            hidden_dim=output_channel,
+            output_dim=output_channel,
             normalization_config=config.transformer_norm_config,
             attention_config=build_attention_config(
                 num_heads=config.transformer_num_attention_heads,
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -212,17 +212,25 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
 # - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
 @lower(torch.ops.aten.slice_scatter)
 def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
-  start = start or 0
-  end = end or self.type.shape[dim]
+  start = start if start is not None else 0
+  end = end if end is not None else self.type.shape[dim]
+
+  start, end = np.clip(
+      [start, end], -self.type.shape[dim], self.type.shape[dim]
+  )
+
   if start < 0:
     start = self.type.shape[dim] + start
   if end < 0:
     end = self.type.shape[dim] + end
 
-  end = start + step * math.ceil((end - start) / step) - (step - 1)
+  if end <= start or np.prod(src.type.shape) == 0:
+    return self
 
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
   padding_low = start
   padding_high = self.type.shape[dim] - end
+  interior_padding = step - 1
 
   rank = len(self.type.shape)
   src = stablehlo.pad(
@@ -230,7 +238,9 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
       utils.splat(0, src.type.element_type, []),
       edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
      edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
-      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+      interior_padding=[
+          interior_padding if i == dim else 0 for i in range(rank)
+      ],
   )
   pred = np.ones(self.type.shape, dtype=np.bool_)
   pred[*[
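
This lowering fix clips out-of-range start/end values and returns the input unchanged when the update is empty, matching PyTorch's own slice semantics. For reference, the behavior being matched, illustrated with plain PyTorch (the values here are arbitrary):

    import torch

    x = torch.zeros(8)

    # `end` past the dimension size is clipped: indices 0, 2, 4, 6 are updated.
    y = torch.slice_scatter(x, torch.ones(4), dim=0, start=0, end=100, step=2)

    # An empty slice leaves the input unchanged.
    z = torch.slice_scatter(x, torch.ones(0), dim=0, start=5, end=5)
    assert z.equal(x)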
ai_edge_torch/odml_torch/lowerings/registry.py CHANGED
@@ -57,6 +57,7 @@ global_registry.decompositions.update(
     torch._decomp.get_decompositions([
         torch.ops.aten.upsample_nearest2d,
         torch.ops.aten._native_batch_norm_legit.no_stats,
+        torch.ops.aten._native_batch_norm_legit_functional,
         torch.ops.aten._adaptive_avg_pool2d,
         torch.ops.aten._adaptive_avg_pool3d,
         torch.ops.aten.grid_sampler_2d,
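
Registering this decomposition lets graphs containing aten._native_batch_norm_legit_functional (the variant functionalization produces for batch norm with tracked running stats) be rewritten into simpler ops before lowering. A quick illustrative check that PyTorch ships a decomposition for it, mirroring the registry's own call:

    import torch
    from torch._decomp import get_decompositions

    # Returns a dict mapping op overloads to Python decomposition functions.
    decomps = get_decompositions(
        [torch.ops.aten._native_batch_norm_legit_functional]
    )
    print(list(decomps.keys()))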
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240914"
+__version__ = "0.3.0.dev20240918"
{ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240914
+Version: 0.3.0.dev20240918
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=rrWwWO1VLdM1khgk2URt5vN4icTeaTqw8CEIsnJRM0E,706
+ai_edge_torch/version.py,sha256=jWg5qA8V0XqgFoqjk0SCsNWPRBeTmfrir9u0bucHYOU,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -39,25 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=bN_dtqi5C_dHpLsvXJ9vCb9OnZ0frLeyYoWBXZYJEqA,3061
-ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=fiFKkEe3TgOdpLnzsCZzIdwvEz0ikxDavQcRGQhlkBY,3053
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
+ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
 ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=DgBuR1uq4YQWfWiENBxrx7UCVr4Jc5kWCyoi6ii5DTE,3058
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=kQTJlCDz_DHLRLlVWE0JEpbOjIGAKtxH1fTSc-jn1nU,8498
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=_tP5ArL0FKiBNoOqN2rG351IzmhNKQmWUfewlcSdKDs,3024
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=mGyBI-nORoI-LhZkI4MFAonkUflIX9iimAer_K8jpck,7088
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=66APmBId5UayZ7SWSO1zxcLiM8TucOMA-fFEHhm61qs,3049
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=_nK2DAOiSuxv5o8ip0i-gmhvvjwF5e7Dm3m5VTcsR2M,4276
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
 ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
 ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -75,12 +75,12 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5u6aOiCVahHNCgax5k9a8uhJn9eMzLa19ldscFKNyWo,3083
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkdczTI-Lr-1GDg0R2g4SUUGEMTUZ5uY,7023
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=37Fua94dQSiBA9Y5XvHxGb5IfN8p8UgNgu5YwM1Rmrw,13057
+ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
 ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
 ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
@@ -90,9 +90,9 @@ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvC
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -108,8 +108,9 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TD7dELN5cVw5
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
+ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
-ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
+ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
@@ -141,13 +142,13 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=E5j_xHuyDmA9fcgoi6p04zLGV9mFleyXzx6jSBi2wD0,8529
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VvB050UCjB17h6-UNtsaqzVF13MGI01fPFkdmmghTj4,8790
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=RN6BwMHuFj_rFgLCZ6Tu32XHbS2HGjPJeir2nROQ2rA,10517
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
-ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
+ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -157,8 +158,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/METADATA,sha256=6NayY4sdwm5Z4jmaIhk17MIQ3_plQOiWX_gGnL3KwPQ,1859
-ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/METADATA,sha256=dMaIr8Iny84IfNGQGSrtlTGkYlH_mAMmgvGWm5-pkxM,1859
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/RECORD,,